import torch 
import torch.nn as nn 

import pandas as pd 
import numpy as np 
import csv
from bert_seq2seq.utils import load_bert
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from sympy import Integer
import re 

# Model identifier passed to load_bert and the checkpoint file whose weights
# are loaded into the model.
model_name = "roberta"
model_path = "./state_dict/bert_math_ques_model.bin"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device is %s" % device)

# Vocabulary loaded from the RoBERTa-wwm vocab file; it is handed to both the
# model loader and the tokenizer.
word2idx = load_chinese_base_vocab("./state_dict/roberta_wwm_vocab.txt")

# Mixed numbers such as "3_1/2" are rewritten to "(3+1/2)" before generation.
_MIXED_NUMBER_RE = re.compile(r"(\d+)_(\d+/\d+)")
# Every integer literal in a decimal-free equation is wrapped in Integer(...).
_INTEGER_RE = re.compile(r"(\d+)")
# A fraction literal appearing in the question text.
_FRACTION_IN_TEXT_RE = re.compile(r"\d+/\d+")
# The textual form of an exact rational result, e.g. "3/4" or "-3/4".
_PLAIN_FRACTION_RE = re.compile(r"(-?\d+)/(\d+)")
# "how many <unit>" questions whose answer has to be a whole number.
_COUNT_UNIT_RE = re.compile("(?:多少|几)[辆人个只箱包本束头盒张]")
_AT_LEAST_RE = re.compile("至少|最少")
_AT_MOST_RE = re.compile("至多|最多")

# Keywords that always turn a fractional answer into a percentage.
_PERCENT_KEYWORDS = (
    "百分之几", "出米率", "出糖率", "利润率", "出粉率",
    "出勤率", "缺勤率", "过标率", "错误率", "出席率",
    "发芽率", "近视率", "含盐率", "命中率",
)


def _wants_percentage(question):
    """Return True when the question asks for a rate expressed in percent."""
    if any(keyword in question for keyword in _PERCENT_KEYWORDS):
        return True
    # These two rates only count when the question is not asking for a count.
    if "成活率" in question and "多少棵" not in question:
        return True
    return "合格率" in question and "多少个" not in question


def _round_to_whole(value, question):
    """Round ``value`` to an int: down for "at most" questions, up otherwise.

    An "at least" keyword takes precedence over an "at most" keyword.
    """
    if not _AT_LEAST_RE.search(question) and _AT_MOST_RE.search(question):
        return int(np.floor(value))
    return int(np.ceil(value))


def evaluate_equation(pred_equation):
    """Evaluate a predicted equation string and return its value.

    If the expression cannot be evaluated, a random integer in [1, 21] is
    returned as a best-effort guess so that every question gets an answer.
    """
    # NOTE(security): eval() executes arbitrary Python. The input here is the
    # model's own output; do not reuse this on untrusted text.
    try:
        return eval(pred_equation)
    except Exception:
        # Only ordinary errors are swallowed; Ctrl-C / SystemExit propagate.
        return np.random.choice(21) + 1


def postprocess_answer(pred_equation, pred_answer, question):
    """Format an evaluated answer according to what the question asks for.

    Args:
        pred_equation: the equation string that was evaluated; a "." in it
            selects the decimal formatting path.
        pred_answer: the evaluated value (a number, or an exact rational whose
            ``str()`` looks like "a/b").
        question: the question text used for keyword-based formatting rules.

    Returns:
        The answer to write out: a string in most cases, or a float rounded
        to 5 places when the question itself contains a fraction or a ":".
    """
    asks_percent = "百分之几" in question

    if "." in pred_equation:
        if asks_percent:
            pred_answer = pred_answer * 100
        pred_answer = round(pred_answer, 2)
        # Drop a trailing ".0" so that 4.0 is reported as 4.
        if int(pred_answer) == pred_answer:
            pred_answer = int(pred_answer)
        if _COUNT_UNIT_RE.search(question):
            pred_answer = _round_to_whole(pred_answer, question)
        pred_answer = str(pred_answer)
        if asks_percent:
            pred_answer += "%"
        return pred_answer

    pred_answer = str(pred_answer)
    fraction = _PLAIN_FRACTION_RE.fullmatch(pred_answer)
    # Integers, non-rational results, and "what fraction" questions are
    # reported exactly as evaluated.
    if fraction is None or "几分之几" in question:
        return pred_answer

    ratio = int(fraction.group(1)) / int(fraction.group(2))
    if _wants_percentage(question):
        return str(round(round(ratio, 5) * 100)) + "%"
    if _FRACTION_IN_TEXT_RE.search(question) or ":" in question:
        return round(ratio, 5)
    return str(_round_to_whole(ratio, question))


if __name__ == "__main__":
    data = pd.read_csv("./test.csv", header=None)
    bert_model = load_bert(word2idx, model_name=model_name)
    bert_model.load_state_dict(torch.load(model_path, map_location=device))
    # NOTE(review): the tokenizer is not used below; kept in case constructing
    # it matters elsewhere -- confirm and remove.
    tokenizer = Tokenizer(word2idx)
    bert_model.to(device)
    bert_model.eval()

    # newline="" is required by the csv module to avoid blank rows on Windows.
    with open("submit.csv", "w", newline="") as f_out:
        writer = csv.writer(f_out)
        for i, row in data.iterrows():
            print(i)
            question = _MIXED_NUMBER_RE.sub(r"(\1+\2)", row[1])
            pred_equation = bert_model.generate(question, beam_size=3, device=device)
            pred_equation = pred_equation.replace(" ", "")
            if "." not in pred_equation:
                # Integer(...) operands keep integer division exact.
                pred_equation = _INTEGER_RE.sub(r"Integer(\1)", pred_equation)
            pred_answer = postprocess_answer(
                pred_equation, evaluate_equation(pred_equation), question
            )
            writer.writerow([row[0], pred_answer])
            # For labelled data (true answer in column 2), compare
            # str(pred_answer) with row[2] here to count and log mistakes.
